#!/usr/bin/python3

import argparse
import sys
import os

import torchvision.transforms as transforms
from torchvision.utils import save_image
from torch.utils.data import DataLoader
from torch.autograd import Variable
import torch

from models import Generator
from datasets import ImageDataset

# Command-line options for a test/inference run of the generator.
parser = argparse.ArgumentParser()
parser.add_argument('--batchSize', type=int, default=1, help='size of the batches')
parser.add_argument('--dataroot', type=str, default='datasets/horse2zebra/', help='root directory of the dataset')
parser.add_argument('--input_nc', type=int, default=3, help='number of channels of input data')
parser.add_argument('--output_nc', type=int, default=3, help='number of channels of output data')
parser.add_argument('--size', type=int, default=256, help='size of the data (squared assumed)')
parser.add_argument('--cuda', action='store_true', help='use GPU computation')
parser.add_argument('--n_cpu', type=int, default=2, help='number of cpu threads to use during batch generation')
# There is no sensible default checkpoint: without a path there is nothing to
# load, so let argparse reject the invocation with a usage message instead of
# failing later inside torch.load(None).
parser.add_argument('--load_model', type=str, required=True, help='The path of the pretrain model')
opt = parser.parse_args()
print(opt)

# Fail early and clearly when the GPU was requested but cannot be used;
# otherwise the script would only crash later when moving the model to CUDA.
if opt.cuda and not torch.cuda.is_available():
    parser.error('--cuda was given but no CUDA device is available')

if torch.cuda.is_available() and not opt.cuda:
    print("WARNING: You have a CUDA device, so you should probably run with --cuda")

###### Definition of variables ######
# Networks
netGen = Generator(opt.input_nc, opt.output_nc)

# Deserialize the checkpoint onto the CPU regardless of the device it was
# saved from: the freshly built model lives on the CPU at this point, and this
# lets a checkpoint written on a GPU machine be used on a CPU-only one. The
# model is moved to the GPU below when --cuda is given.
# NOTE(review): torch.load unpickles the file, so --load_model should only
# ever point at a trusted checkpoint.
checkpoint = torch.load(opt.load_model, map_location='cpu')
netGen.load_state_dict(checkpoint['Gen_state_dict'])

if opt.cuda:
    netGen.cuda()

# Set model's test mode
netGen.eval()

# Inputs & targets memory allocation: one reusable buffer shaped
# (batchSize, input_nc, size, size), on the GPU when --cuda is given.
Tensor = torch.cuda.FloatTensor if opt.cuda else torch.Tensor
input_photo = Tensor(opt.batchSize, opt.input_nc, opt.size, opt.size)

# Dataset loader
# Preprocessing pipeline handed to the dataset: convert each image to a
# tensor, then normalise every one of the three channels with mean 0.5 and
# std 0.5.
transforms_ = [
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
]
# Iterate the 'test' split in its stored order (no shuffling), so the running
# batch index used for the output file names is reproducible between runs.
dataloader = DataLoader(
    dataset=ImageDataset(opt.dataroot, transforms_=transforms_, mode='test'),
    batch_size=opt.batchSize,
    shuffle=False,
    num_workers=opt.n_cpu,
)
###################################

###### Testing ######

# Create the output directory if it does not exist yet. exist_ok=True does
# this in one call, so there is no window between an existence check and the
# creation in which another process could create the directory and make
# makedirs raise.
os.makedirs('output/cartoon_Gen', exist_ok=True)

def translate_batch(generator, photos, device):
    """Run ``generator`` on one batch and rescale its output for saving.

    Args:
        generator: Callable mapping an image batch tensor to an output
            tensor (here: the loaded generator network).
        photos: Input batch tensor as yielded by the data loader.
        device: ``torch.device`` the generator lives on.

    Returns:
        ``0.5 * (generator(photos) + 1.0)``, computed without autograd.
        Assuming the generator emits values in [-1, 1] (mirroring the
        0.5/0.5 input normalisation), this maps them to [0, 1].
    """
    # Move the batch itself to the target device instead of copying it into a
    # pre-allocated (batchSize, C, size, size) buffer: copy_ broadcasts, so a
    # final batch of size 1 was silently replicated, and any other smaller
    # final batch or differently sized image raised a shape error.
    photos = photos.to(device=device, dtype=torch.float32)

    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        return 0.5 * (generator(photos) + 1.0)


if __name__ == '__main__':
    device = torch.device('cuda' if opt.cuda else 'cpu')

    for i, batch in enumerate(dataloader):
        # Generate output
        fake_cartoon = translate_batch(netGen, batch['input_photo'], device)

        # Save image files (one file per batch, numbered from 0001)
        save_image(fake_cartoon, 'output/cartoon_Gen/%04d.png' % (i+1))

    sys.stdout.write('\n')
###################################
